#!/usr/bin/env python

import argparse
import glob
import logging
import os
import platform
import re
import sys
import time
from functools import partial
from threading import Lock

from tools.utils import CountDownLatch
from concurrent.futures import ThreadPoolExecutor

import config

# Import from parent directory
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:  # for backward-compatibility (in the main repository)
    import hmark.parseutility as parser
except ImportError:  # for subrepo
    import tools.parseutility as parser

# GLOBALS
originalDir = config.ROOT_PATH  # vuddy root directory
diffDir = os.path.join(originalDir, "diff")  # where per-repo DIFF patches live
# resultList = []
# Placeholder Function used when a patched function no longer exists in the
# fixed revision (its .name is None).
dummyFunction = parser.Function(None)
multimodeFlag = 0  # 1 when -m/--multimode is given (multi-repo DIFFs)
debugMode = config.DEBUG_MODE

# Shared progress counters; guarded by their locks since patches may be
# processed from a thread pool.
diffFileCnt = 0
diffFileCntLock = Lock()
functionCnt = 0
functionCntLock = Lock()

t1 = time.time()  # start timestamp, used for the elapsed-time report in main()

""" re patterns """
# Splits a multi-file patch into one string per affected file.
PATTERN_SRC = re.compile(r'[\n](?=diff --git a/)')
# Splits one file's diff into hunks ("@@ -a,b +c,d @@" headers).
PATTERN_CHUNK = re.compile(r'[\n](?=@@\s[^a-zA-Z]*\s[^a-zA-Z]*\s@@)')
# Captures the old ("-a,b") and new ("+c,d") line ranges of a hunk header.
PATTERN_LINE_NUM = re.compile(r'-(\d+,\d+) \+(\d+,\d+) ')
# Captures the 40-hex-char commit id embedded in the patch filename.
PATTERN_COMMIT = re.compile(r'[^a-zA-Z0-9]([a-zA-Z0-9]{40})[^a-zA-Z0-9]')

"""
Introduction: 根据fix的diff文件，通过一系列正则得到修改了哪些文件，将这些被修改了的文件cp临时一份，
接下来要做的就是确定这些修改的文件中是修改了哪些函数（这些函数就是漏洞函数），首先使用FuncParser-opt函数解析工具
解析那些文件得到文件中所有的函数，接下来就是定位哪些函数是漏洞函数，通过diff文件中标识的修改位置，同时匹配解析出来的函数
对应的文件中的行号来确定
"""


def init():
    """Parse command-line arguments and prepare the output directories.

    Sets the module-level globals ``repoName``, ``multimodeFlag``,
    ``debugMode``, ``abstract_level`` and ``total``, and creates the
    ``tmp``, ``vul/<repo>`` and ``abs/<repo>`` directories if missing.
    """
    # ARGUMENTS
    global repoName
    global multimodeFlag
    global total
    global debugMode
    global abstract_level

    parser_ = argparse.ArgumentParser()
    parser_.add_argument('REPO',
                         help='''Repository name''')
    parser_.add_argument('-m', '--multimode', action="store_true",
                         help='''Turn on Multimode''')
    parser_.add_argument('-d', '--debug', action="store_true", help=argparse.SUPPRESS)  # Hidden Debug Mode
    # BUGFIX: no nargs=1 here. With nargs=1, an explicit "-a N" produced the
    # list [N] while the default stayed the int 5, so downstream consumers of
    # abstract_level saw inconsistent types.
    parser_.add_argument('-a', '--abstract-level', dest='abstract_level', type=int, choices=[0, 1, 2, 3, 4, 5],
                         default=5,
                         help='''Abstract Level''')
    args = parser_.parse_args()

    # (REPO is a required positional, so argparse itself errors out when it
    # is missing; no extra None check is needed.)
    repoName = args.REPO  # name of the directory that holds DIFF patches
    if args.multimode:
        multimodeFlag = 1
    if args.debug:
        debugMode = True

    msg = "Retrieve vulnerable functions from {0}\nMulti-repo mode: ".format(repoName)
    if multimodeFlag:
        logging.info(msg + "On")
    else:
        logging.info(msg + "Off")

    # Create missing output directories. exist_ok=True tolerates existing
    # dirs while still surfacing real failures (e.g. permissions), unlike the
    # previous blanket "except OSError: pass".
    os.makedirs(os.path.join(originalDir, 'tmp'), exist_ok=True)
    os.makedirs(os.path.join(originalDir, 'vul', repoName), exist_ok=True)
    os.makedirs(os.path.join(originalDir, 'abs', repoName), exist_ok=True)
    total = len(os.listdir(os.path.join(diffDir, repoName)))
    abstract_level = args.abstract_level


def source_from_cvepatch(diffFileName, cdl: CountDownLatch = None):
    """Reconstruct the vulnerable (pre-patch) functions touched by one DIFF.

    ``diffFileName`` looks like:
    ``CVE-2012-2372_7a9bc620049fed37a798f478c5699a11726b3d33.diff``

    For every affected source file in the patch, the pre-/post-fix blobs are
    dumped with ``git show``, parsed into function instances, and each old
    function whose line span intersects a meaningfully changed hunk line
    (i.e. not a comment-only change) is written out as ``*_OLD.vul`` /
    ``*_NEW.vul`` plus a unified ``.patch`` under ``vul/<repo>``, with
    serialized feature and change info under ``abs/<repo>``.

    :param diffFileName: filename of one DIFF patch inside ``diff/<repo>``.
    :param cdl: optional latch; counted down once this patch is processed.
    """
    global repoName
    global debugMode
    global total
    global multimodeFlag
    global dummyFunction
    global diffDir
    global originalDir
    global diffFileCnt
    global functionCnt

    chunksCnt = 0  # number of DIFF patches

    with diffFileCntLock:
        logging.info(str(diffFileCnt + 1) + '/' + str(total))
        diffFileCnt += 1

    if os.path.getsize(os.path.join(diffDir, repoName, diffFileName)) > 1000000:
        # don't do anything with big DIFFs (merges, upgrades, ...).
        logging.info("[-]%s\t(file too large)" % diffFileName)
    else:
        # The 40-hex commit id is embedded in the patch filename.
        commit_id = PATTERN_COMMIT.search(os.path.basename(diffFileName)).group(1)

        logging.info("[+]%s\t(proceed)" % diffFileName)
        with open(os.path.join(diffDir, repoName, diffFileName), 'r', encoding=config.ENCODING) as fp:
            patchLines = ''.join(fp.readlines())
            patchLinesSplitted = PATTERN_SRC.split(patchLines)
            commitLog = patchLinesSplitted[0]
            affectedFilesList = patchLinesSplitted[1:]

        repoPath = ''
        if multimodeFlag:  # multimode DIFFs have repoPath at the beginning.
            repoPath = commitLog.split('\n')[0].rstrip().lstrip("\xef\xbb\xbf")

        numAffectedFiles = len(affectedFilesList)
        for aidx, affectedFile in enumerate(affectedFilesList):
            if debugMode:
                logging.info("\tFile # " + str(aidx + 1) + '/' + str(numAffectedFiles))
            firstLine = affectedFile.split('\n')[0]  # git --diff a/path/filename.ext b/path/filename.ext
            affectedFileName = firstLine.split("--git ")[1].split(" ")[0].split("/")[-1]
            codePath = firstLine.split(' b')[1].strip()  # path/filename.ext

            if not parser.filename_ends_in(codePath):
                if debugMode:
                    # BUGFIX: was logging.info("\t[-]", codePath, ...) which made
                    # logging raise "not all arguments converted" and drop the message.
                    logging.info("\t[-] %s (wrong extension)", codePath)
            else:
                secondLine = affectedFile.split('\n')[1]

                if not secondLine.startswith("index"):  # or secondLine.endswith("100644") == 0:
                    if debugMode:
                        logging.info("\t[-] " + codePath + " (invalid metadata)")  # we are looking for "index" only.
                else:
                    if debugMode:
                        logging.info("\t[+] %s" % codePath)
                    # "index <oldhash>..<newhash> <mode>" -> pre-/post-fix blob ids
                    indexHashOld, indexHashNew = secondLine.split(' ')[1].split('..')

                    chunksList = PATTERN_CHUNK.split(affectedFile)[1:]  # diff file per chunk (in list)
                    chunksCnt += len(chunksList)

                    # git show must run inside the repository's working tree.
                    if multimodeFlag:
                        os.chdir(os.path.join(config.GIT_STORAGE_PATH, repoName, repoPath))
                    else:
                        os.chdir(os.path.join(config.GIT_STORAGE_PATH, repoName))

                    tmpOldFileName = os.path.join(originalDir, "tmp", "{0}_{1}_old".format(repoName, functionCnt))
                    command_show = "\"{0}\" show {1} > {2}".format(config.GIT_BIN, indexHashOld, tmpOldFileName)
                    os.system(command_show)

                    tmpNewFileName = os.path.join(originalDir, "tmp", "{0}_{1}_new".format(repoName, functionCnt))
                    command_show = "\"{0}\" show {1} > {2}".format(config.GIT_BIN, indexHashNew, tmpNewFileName)
                    os.system(command_show)

                    # NOTE(review): re-joining with originalDir is a no-op when
                    # originalDir is absolute (os.path.join returns the absolute
                    # second operand unchanged) -- kept for safety.
                    tmpOldFileName = os.path.join(originalDir, tmpOldFileName)
                    tmpNewFileName = os.path.join(originalDir, tmpNewFileName)
                    oldFunctionInstanceList = parser.DEFAULT_FUNC_PARSER.parse(tmpOldFileName, True)
                    newFunctionInstanceList = parser.DEFAULT_FUNC_PARSER.parse(tmpNewFileName, True)

                    finalOldFunctionList = []

                    numChunks = len(chunksList)
                    for ci, chunk in enumerate(chunksList):
                        if debugMode:
                            logging.info("\t\tChunk # " + str(ci + 1) + "/" + str(numChunks))

                        chunkSplitted = chunk.split('\n')
                        chunkFirstLine = chunkSplitted[0]
                        chunkLines = chunkSplitted[1:]

                        if debugMode:
                            logging.info(chunkFirstLine)
                        lineNums = PATTERN_LINE_NUM.search(chunkFirstLine)
                        oldLines = lineNums.group(1).split(',')

                        # Map each hunk line back to a line number in the OLD file:
                        # context (' ') and deleted ('-') lines advance the old
                        # file; added ('+') lines do not, hence the offset decrement.
                        offset = int(oldLines[0])
                        pmList = []
                        lnList = []
                        for chunkLine in chunkSplitted[1:]:
                            if len(chunkLine) != 0:
                                pmList.append(chunkLine[0])

                        for i, pm in enumerate(pmList):
                            if pm == ' ' or pm == '-':
                                lnList.append(offset + i)
                            elif pm == '+':
                                lnList.append(offset + i - 1)
                                offset -= 1

                        """ HERE, ADD CHECK FOR NEW FUNCTIONS """
                        # An old function is "hit" when any of its lines appears
                        # in this hunk's old-file line numbers.
                        hitOldFunctionList = []
                        for f in oldFunctionInstanceList:
                            for num in range(f.lines[0], f.lines[1] + 1):
                                if num in lnList:
                                    hitOldFunctionList.append(f)
                                    break  # found the function to be patched

                        # Keep a hit function only if at least one changed line
                        # inside it is meaningful (not a comment-only change).
                        for f in hitOldFunctionList:
                            for num in range(f.lines[0], f.lines[1] + 1):
                                try:
                                    listIndex = lnList.index(num)
                                except ValueError:
                                    pass
                                else:
                                    if lnList.count(num) > 1:
                                        listIndex += 1
                                    if pmList[listIndex] == '+' or pmList[listIndex] == '-':
                                        flag = 0
                                        for commentKeyword in ["/*", "*/", "//", "*"]:
                                            if chunkLines[listIndex][1:].lstrip().startswith(commentKeyword):
                                                flag = 1
                                                break
                                        if flag:
                                            pass  # comment-only change: not meaningful
                                        else:
                                            finalOldFunctionList.append(f)
                                            break
                                    else:
                                        pass  # context line: not meaningful

                    finalOldFunctionList = list(set(finalOldFunctionList))  # sometimes list has dups

                    # Pair each old function with its same-named new function;
                    # dummyFunction (name=None) marks functions removed by the fix.
                    finalNewFunctionList = []
                    for fold in finalOldFunctionList:
                        flag = 0
                        for fnew in newFunctionInstanceList:
                            if fold.name == fnew.name:
                                finalNewFunctionList.append(fnew)
                                flag = 1
                                break
                        if not flag:
                            finalNewFunctionList.append(dummyFunction)

                    if debugMode:
                        # BUGFIX: was logging.info("\t\t\t", len(...), "functions found.")
                        # which made logging raise internally and drop the message.
                        logging.info("\t\t\t%d functions found.", len(finalNewFunctionList))
                    vulFileNameBase = diffFileName.split('.diff')[0] + '_' + affectedFileName

                    for index, f in enumerate(finalOldFunctionList):
                        oldFuncInstance = finalOldFunctionList[index]

                        # BUGFIX: the old open()/close() pair leaked the handle
                        # when readlines() raised (the except branch skipped close()).
                        try:
                            with open(os.path.join(originalDir, oldFuncInstance.parentFile), 'r', encoding=config.ENCODING) as fp:
                                srcFileRaw = fp.readlines()
                        except Exception as e:
                            # e.g. decode errors: skip this function, keep processing
                            logging.exception(e)
                            continue
                        finalOldFunction = ''.join(srcFileRaw[oldFuncInstance.lines[0] - 1:oldFuncInstance.lines[1]])

                        finalOldFuncId = str(oldFuncInstance.funcId)

                        newFuncInstance = finalNewFunctionList[index]

                        if newFuncInstance.name is None:
                            # Function was deleted by the fix: no NEW body.
                            finalNewFunction = ""
                        else:
                            with open(os.path.join(originalDir, newFuncInstance.parentFile), 'r', encoding=config.ENCODING) as fp:
                                srcFileRaw = fp.readlines()
                            finalNewFunction = ''.join(
                                srcFileRaw[newFuncInstance.lines[0] - 1:newFuncInstance.lines[1]])

                        # Compare normalized, comment-stripped bodies: if they are
                        # equal, only comments were patched and we skip the pair.
                        finalOldBody = finalOldFunction[finalOldFunction.find('{') + 1:finalOldFunction.rfind('}')]
                        finalNewBody = finalNewFunction[finalNewFunction.find('{') + 1:finalNewFunction.rfind('}')]
                        tmpold = parser.normalize(parser.remove_comment(finalOldBody))
                        tmpnew = parser.normalize(parser.remove_comment(finalNewBody))

                        if tmpold != tmpnew and len(tmpnew) > 0:
                            # if two are same, it means nothing but comment is patched.
                            with functionCntLock:
                                functionCnt += 1
                            bash_path = os.path.join(originalDir, "vul", repoName)
                            vulOldFileName = vulFileNameBase + '_' + finalOldFuncId + "_OLD.vul"
                            vulNewFileName = vulFileNameBase + '_' + finalOldFuncId + "_NEW.vul"
                            vulOldFileName = os.path.join(bash_path, vulOldFileName)
                            vulNewFileName = os.path.join(bash_path, vulNewFileName)
                            with open(vulOldFileName, 'w', encoding=config.ENCODING) as fp:
                                fp.write(finalOldFunction)
                            with open(vulNewFileName, 'w', encoding=config.ENCODING) as fp:
                                if finalNewFunctionList[index].name is not None:
                                    fp.write(finalNewFunction)
                                else:
                                    fp.write("")

                            os.chdir(bash_path)
                            patch_file = '%s_%s.patch' % (vulFileNameBase, finalOldFuncId)
                            diffCommand = "\"{0}\" -u {1} {2} > {3}".format(config.DIFF_BIN,
                                                                            vulOldFileName,
                                                                            vulNewFileName,
                                                                            patch_file)
                            os.system(diffCommand)
                            # write abs info
                            oldFuncInstance = parser.DEFAULT_FUNC_PARSER.parse(vulOldFileName, False)[0]
                            newFuncInstance = parser.DEFAULT_FUNC_PARSER.parse(vulNewFileName, False)[0]
                            oldFuncInstance.parentFile = codePath

                            base_path_ = os.path.join("abs", repoName)
                            features_store_path = os.path.join(base_path_, os.path.basename(vulOldFileName) + '.features')
                            changes_store_path = os.path.join(base_path_, os.path.basename(vulOldFileName) + '.changes')

                            func_features = parser.extract_features(
                                oldFuncInstance,
                                abstract_level,
                                vuln_info=get_vuln_info(os.path.basename(vulOldFileName)),
                                repo_name=repoName,
                                fixed_commit_id=commit_id,
                                changes_info=changes_store_path.replace(os.sep, '/')
                            )
                            sedes = func_features.get_sedes()
                            func_features_serialized = sedes.serialize(func_features)

                            patch_diff_info = parser.extract_patch_diff_info(
                                patch_file,
                                parser.union_function_list_params(oldFuncInstance, newFuncInstance),
                                abstract_level
                            )
                            sedes = patch_diff_info.get_sedes()
                            patch_diff_info_serialized = sedes.serialize(patch_diff_info)

                            # extract func features
                            with open(os.path.join(originalDir, features_store_path), 'w', encoding=config.ENCODING) as fp:
                                # write func features first
                                fp.write(func_features_serialized)
                            # extract diff info (includes add and del)
                            with open(os.path.join(originalDir, changes_store_path), 'w', encoding=config.ENCODING) as fp:
                                # write add and del info
                                fp.write(patch_diff_info_serialized)
    if cdl:
        cdl.count_down()


def get_vuln_info(filename: str):
    return '_'.join(filename.split('_')[:-4])


def main():
    """Process every DIFF patch under ``diff/<repo>`` and log statistics."""
    diffList = os.listdir(os.path.join(diffDir, repoName))
    if debugMode or "Windows" in platform.platform():
        # Windows - do not use multiprocessing
        # Using multiprocessing will lower performance
        for diffFile in diffList:
            source_from_cvepatch(diffFile)
    else:  # POSIX - use multiprocessing
        cdl = CountDownLatch(count=len(diffList))
        parallel_partial = partial(source_from_cvepatch, cdl=cdl)
        # NOTE(review): max_workers=1 keeps the work effectively serialized;
        # source_from_cvepatch calls the process-global os.chdir(), so raising
        # the worker count without removing the chdir calls would be unsafe.
        with ThreadPoolExecutor(max_workers=1) as pool:
            pool.map(parallel_partial, diffList)

    # delete temp source files
    wildcard_temp = os.path.join(originalDir, "tmp", repoName + "_*")
    for f in glob.glob(wildcard_temp):
        os.remove(f)

    logging.info("")
    logging.info("Done getting vulnerable functions from %s" % repoName)
    # BUGFIX: message used to read "... functions from + N patches" (stray '+').
    logging.info("Reconstructed " + str(functionCnt) + " vulnerable functions from " +
                 str(diffFileCnt) + " patches.")
    logging.info("Elapsed: %.2f sec" % (time.time() - t1))


if __name__ == "__main__":
    config.conf_log()
    init()
    main()
